PyTorch Lightningで書き換え
LightningModule
ネットワークの定義
訓練で呼ばれるメソッドの上書き
例:メトリクス
Trainer
LightningModuleとそれに渡すデータを扱う
個々のDataLoaderの例
データをまとめられる
3つを使うとコードが劇的にスッキリする
code:python
dm = MNISTDataModule()
model = Model()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)